from collections import OrderedDict

import torch
from apex.optimizers import FusedAdam
from apex.contrib.sparsity import ASP


def build_model(args):
    """Build the stacked Linear + LayerNorm network used by this test.

    The network has ``args.num_layers`` stages. Stage ``k`` (1-based) adds two
    entries, in this order:

    * ``linear_layer_k`` -- ``torch.nn.Linear(in_k, out_k)``
    * ``layer_norm_k``   -- ``torch.nn.LayerNorm([args.batch_size, out_k])``

    ``in_k`` is ``args.input_features`` for the first stage and
    ``args.hidden_features`` otherwise. ``out_k`` is ``args.output_features``
    for the last stage and ``args.hidden_features`` otherwise. A single-stage
    network therefore maps ``input_features`` directly to ``output_features``.

    Args:
        args: Object exposing ``num_layers``, ``input_features``,
            ``hidden_features``, ``output_features`` and ``batch_size``.

    Returns:
        A ``torch.nn.Sequential`` whose submodule (and state_dict) names are
        ``linear_layer_k`` / ``layer_norm_k``.
    """
    od = OrderedDict()
    last = args.num_layers - 1
    for i in range(args.num_layers):
        # Pick widths independently so the first stage can also be the last
        # one (num_layers == 1).
        in_features = args.input_features if i == 0 else args.hidden_features
        out_features = args.output_features if i == last else args.hidden_features
        # Linear is created before LayerNorm, as before, so parameter
        # initialisation consumes the RNG in the same order.
        od["linear_layer_%d" % (i + 1)] = torch.nn.Linear(in_features, out_features)
        # normalized_shape spans both dims of a [batch, features] activation.
        # The model therefore only accepts inputs whose leading dimension is
        # exactly args.batch_size.
        od["layer_norm_%d" % (i + 1)] = torch.nn.LayerNorm(
            [args.batch_size, out_features]
        )
    return torch.nn.Sequential(od)


def train_step(args, model, optimizer, input_batch, target_batch, step):
    """Run one optimisation step and return the incremented step counter.

    The loss is the summed squared error between ``model(input_batch)`` and
    ``target_batch``. It is backpropagated, ``optimizer.step()`` is applied,
    and the gradients are cleared with ``optimizer.zero_grad()``.

    ``args`` is not read here.
    """
    residual = model(input_batch) - target_batch
    squared_error = (residual ** 2).sum()
    squared_error.backward()
    optimizer.step()
    optimizer.zero_grad()
    return step + 1


def train_loop(args, model, optimizer, step, num_steps):
    """Train ``model`` on ``num_steps`` freshly sampled random batches.

    Each iteration draws an input batch of shape
    ``[args.batch_size, args.input_features]`` and a target batch of shape
    ``[args.batch_size, args.output_features]`` with ``torch.randn``. It then
    runs one ``train_step`` on them.

    Args:
        args: Object exposing ``batch_size``, ``input_features`` and
            ``output_features``.
        model: Module to train. Batches are moved to the device of its first
            parameter, or to CUDA if it has no parameters.
        optimizer: Optimizer stepping ``model``'s parameters.
        step: Step counter value before the first iteration.
        num_steps: Number of iterations to run; ``0`` runs none.

    Returns:
        The step counter after all iterations.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    for _ in range(num_steps):
        # Sample on the CPU and then copy. Drawing directly on the GPU would
        # use a different generator and change the seeded batch values.
        input_batch = torch.randn([args.batch_size, args.input_features]).to(device)
        target_batch = torch.randn([args.batch_size, args.output_features]).to(device)
        step = train_step(args, model, optimizer, input_batch, target_batch, step)
    return step


def main(step, args, model_state_dict, optimizer_state_dict):
    """Resume sparse training from a part-1 checkpoint (part 2 of the test).

    Rebuilds the model on CUDA and creates a ``FusedAdam`` optimizer.
    Registers both with ``ASP`` for pruning, restores their state from the
    given state dicts, then trains for ``args.num_sparse_steps_2`` more steps.

    Args:
        step: Step counter saved by part 1.
        args: Config object. Reads the fields used by ``build_model`` and
            ``train_loop``, plus ``pattern``, ``verbosity``, ``whitelist``,
            ``allow_recompute_mask``, ``seed2`` and ``num_sparse_steps_2``.
        model_state_dict: State dict loaded into the rebuilt model.
        optimizer_state_dict: State dict loaded into the optimizer.

    Returns:
        The step counter after the additional sparse training steps.
    """
    #
    # PART2
    #

    model = build_model(args).cuda()
    # Handle on the first Linear layer's weight, printed below.
    # load_state_dict copies values into the existing parameter tensors, so
    # this reference still sees the restored weights.
    first_weight = next(model.children()).weight
    optimizer = FusedAdam(model.parameters())

    # NOTE(review): ASP is initialised *before* the state dicts are loaded,
    # presumably so that pruning-related checkpoint entries have matching
    # slots to load into. Confirm against the ASP docs before reordering.
    ASP.init_model_for_pruning(
        model,
        args.pattern,
        verbosity=args.verbosity,
        whitelist=args.whitelist,
        allow_recompute_mask=args.allow_recompute_mask,
    )
    ASP.init_optimizer_for_pruning(optimizer)

    # Seed before training so the random batches drawn in train_loop are
    # reproducible.
    torch.manual_seed(args.seed2)
    model.load_state_dict(model_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

    print("Model sparsity is %s" % ("enabled" if ASP.is_sparsity_enabled() else "disabled"))

    # train for a few steps with sparse weights
    print("SPARSE :: ", first_weight)
    step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2)
    return step


if __name__ == "__main__":
    # One name for the checkpoint location. It is both the file loaded here
    # and the value exposed as Args.checkpoint_path.
    PART1_CHECKPOINT = "part1.chkp"
    # weights_only=False: the checkpoint carries non-tensor config next to the
    # state dicts (e.g. the whitelist handed to ASP, which presumably holds
    # module classes; verify against part 1). The weights-only unpickler, the
    # default in recent PyTorch, rejects such objects. This fully unpickles
    # the file, so only load checkpoints you produced yourself.
    checkpoint = torch.load(PART1_CHECKPOINT, weights_only=False)

    class Args:
        # Settings restored from the part-1 run.
        verbosity = checkpoint["verbosity"]
        seed = 4873
        seed2 = checkpoint["seed2"]
        pattern = checkpoint["pattern"]
        whitelist = checkpoint["whitelist"]
        allow_recompute_mask = checkpoint["allow_recompute_mask"]
        # Model-shape fields: build_model sizes every Linear/LayerNorm from
        # these, so they must match the run that wrote the checkpoint or
        # load_state_dict will reject its tensors.
        batch_size = 32
        input_features = 8
        output_features = 8
        hidden_features = 32
        num_layers = 4
        # Not read by main() in this script.
        num_dense_steps = 2000
        num_sparse_steps = 3000
        # Number of extra sparse steps main() trains for.
        num_sparse_steps_2 = 1000
        checkpoint_path = PART1_CHECKPOINT

    args = Args()

    main(
        checkpoint["step"],
        args,
        checkpoint["model_state_dict"],
        checkpoint["optimizer_state_dict"],
    )
